# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import logging
from ollama import Client

from kag.interface import LLMClient


# Module-level logger; logging configuration is left to the application.
logger = logging.getLogger(__name__)


@LLMClient.register("Ollama")
@LLMClient.register("ollama")
class OllamaClient(LLMClient):
    """
    A client class for interacting with the Ollama API.

    This class provides methods to make synchronous (non-streaming) requests to
    the Ollama ``generate`` endpoint, forward optional image input, and
    post-process the returned text. It is registered with ``LLMClient`` under
    both "Ollama" and "ollama".
    """

    # HTML entities that may leak into model output, paired with the text that
    # replaces them. Curly quotes are decoded; "&middot;" is removed entirely.
    # The replacements cannot create new entities, so their order is irrelevant.
    _ENTITY_REPLACEMENTS = (
        ("&rdquo;", "”"),
        ("&ldquo;", "“"),
        ("&middot;", ""),
    )

    def __init__(
        self,
        model: str,
        base_url: str,
        timeout: float = None,
    ):
        """
        Initializes the OllamaClient instance.

        Args:
            model (str): The model to use for requests.
            base_url (str): The base URL (host) of the Ollama server.
            timeout (float, optional): The timeout duration for the service
                request. Defaults to None, meaning no timeout.
        """
        self.model = model
        self.base_url = base_url
        self.timeout = timeout
        self.param = {}
        self.client = Client(host=self.base_url, timeout=self.timeout)
        # `check` is inherited from LLMClient (not defined in this file);
        # presumably it validates that the client is usable — TODO confirm.
        self.check()

    def sync_request(self, prompt, image=None):
        """
        Makes a synchronous request to the Ollama API with the given prompt.

        Args:
            prompt: The prompt to send to the Ollama API.
            image: Optional image data to include in the request. A single
                image is wrapped in a one-element list; a list or tuple is
                forwarded as multiple images. The accepted formats are those of
                the ollama client's ``images`` argument (e.g. base64 string,
                bytes, or file path). None sends no images.

        Returns:
            str: The text of the response, with selected HTML entities
            decoded or removed.
        """
        request = {"model": self.model, "prompt": prompt, "stream": False}
        if image is not None:
            # Previously the image was accepted but silently dropped; the
            # Ollama generate API takes images as a sequence.
            if isinstance(image, (list, tuple)):
                request["images"] = list(image)
            else:
                request["images"] = [image]

        response = self.client.generate(**request)
        content = response["response"]
        for entity, replacement in self._ENTITY_REPLACEMENTS:
            content = content.replace(entity, replacement)

        return content

    def __call__(self, prompt, image=None):
        """
        Executes a model request when the object is called and returns the result.

        Parameters:
            prompt (str): The prompt provided to the model.
            image: Optional image data forwarded to `sync_request`.

        Returns:
            str: The response content generated by the model.
        """

        return self.sync_request(prompt, image)
